{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "11abe3e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/cuda117/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([   1,    1,    0, 9178,   32,   47]),\n",
       " tensor([0, 0, 1, 1, 1, 1]),\n",
       " '<pad><pad><s>how are you')"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import random\n",
    "import math\n",
    "\n",
    "from util import TokenizerUtil\n",
    "\n",
    "# Project-local tokenizer wrapper, reused by the cells below.\n",
    "tokenizer = TokenizerUtil()\n",
    "\n",
    "# Round-trip a short phrase: encode with max_length=6, left-pad, decode.\n",
    "# encode() returns a pair; only the token ids are kept here.\n",
    "input_ids, _ = tokenizer.encode('how are you', max_length=6)\n",
    "input_ids, attention_mask = tokenizer.pad_to_left(input_ids)\n",
    "\n",
    "# The parenthesised tuple is the cell's displayed result.\n",
    "(input_ids, attention_mask, tokenizer.decode(input_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5e7480ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.60s/it]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "# Load the causal LM checkpoint stored at 'model/rlhf' with float16 weights,\n",
    "# letting device_map place it on the CUDA device.\n",
    "model_actor = AutoModelForCausalLM.from_pretrained(\n",
    "    'model/rlhf',\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map='cuda',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fc404ebc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8,\n",
       " {'prompt': 'Human: context= CREATE TABLE table_name_86 (rd VARCHAR, _time VARCHAR, date VARCHAR) question= What was the Rd. Time for October 3, 2009? Assistant:',\n",
       "  'chosen': 'SELECT rd, _time FROM table_name_86 WHERE date = \"october 3, 2009\"',\n",
       "  'rejected': '',\n",
       "  'response': 'SELECT rd, _time FROM table_name_86 WHERE date = \"october 3, 2009\"'})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# The eval set is read as JSON Lines (one JSON object per line).\n",
    "# Pin the encoding so decoding does not depend on the platform default.\n",
    "with open('dataset/eval.json', encoding='utf-8') as f:\n",
    "    # Skip blank lines (e.g. a trailing newline) that json.loads would reject.\n",
    "    lines = [i for i in f if i.strip()]\n",
    "\n",
    "# Sample with a dedicated seeded generator so a fresh Run All prints the same\n",
    "# examples without touching the global `random` state. Capping k avoids the\n",
    "# ValueError random.sample raises when fewer than 8 records exist.\n",
    "lines = random.Random(0).sample(lines, k=min(8, len(lines)))\n",
    "lines = [json.loads(i) for i in lines]\n",
    "\n",
    "len(lines), lines[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2ceb369d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "select RD FROM table_name_86 WHERE _time = \"october 3, 2009\" AND date = \"october 3, 2009\"</s>\n",
      "SELECT rd, _time FROM table_name_86 WHERE date = \"october 3, 2009\"\n",
      "===============\n",
      "select ENGINE from TABLE_NAME_27 where YEAR = \"61\" and POINTS = \"0\" and CHASSIS = \"porsche 718\"</s>\n",
      "SELECT engine FROM table_name_27 WHERE points = 0 AND chassis = \"porsche 718\" AND year = 1961\n",
      "===============\n",
      "select STATE from TABLE_NAME_43 where TEAM = \"MUMBAI\"</s>\n",
      "SELECT state FROM table_name_43 WHERE team = \"mumbai\"\n",
      "===============\n",
      "select MIN(2nd_baseman) from TABLE_12142298_2 where FIRST_BASEMAN = \"nomar gARCIAAPARRA\"</s>\n",
      "SELECT second_baseman FROM table_12142298_2 WHERE first_baseman = \"Nomar Garciaparra\"\n",
      "===============\n",
      "select SCIENCE_AGENCY from TABLE_NAME_17 where NAME = \"HERSChel Space Observatory\"</s>\n",
      "SELECT space_agency FROM table_name_17 WHERE name = \"herschel space observatory\"\n",
      "===============\n",
      "select DATE from TABLE_NAME_14 where PEARCE = \"Samba\"</s>\n",
      "SELECT pair FROM table_name_14 WHERE date = \"october 14, 2008\" AND dance = \"samba\"\n",
      "===============\n",
      "select POINTS from TABLE_NAME_96 where PLAYED = \"42\"</s>\n",
      "SELECT AVG(points) FROM table_name_96 WHERE played > 42\n",
      "===============\n",
      "select DRIVER from TABLE_NAME_60 where LAPS < 9 and GRID < 13</s>\n",
      "SELECT driver FROM table_name_60 WHERE laps < 9 AND grid = 13\n",
      "===============\n"
     ]
    }
   ],
   "source": [
    "# For each sampled record, generate a completion for its prompt and print it\n",
    "# above the reference ('chosen') answer for a side-by-side comparison.\n",
    "for data in lines:\n",
    "    input_ids, _ = tokenizer.encode(data['prompt'], max_length=256)\n",
    "    input_ids, attention_mask = tokenizer.pad_to_left(input_ids)\n",
    "\n",
    "    # Add a batch dimension and move the tensors to wherever the model lives,\n",
    "    # rather than assuming a fixed device name.\n",
    "    input_ids = input_ids.unsqueeze(0).to(model_actor.device)\n",
    "    attention_mask = attention_mask.unsqueeze(0).to(model_actor.device)\n",
    "\n",
    "    # max_length counts prompt tokens plus newly generated tokens.\n",
    "    generate = model_actor.generate(input_ids=input_ids,\n",
    "                                    attention_mask=attention_mask,\n",
    "                                    max_length=512,\n",
    "                                    pad_token_id=tokenizer.pad_token_id,\n",
    "                                    eos_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "    # The returned sequence starts with the prompt; drop it using the actual\n",
    "    # prompt length instead of repeating the 256 passed to encode() above.\n",
    "    generate = generate[0, input_ids.shape[1]:].to('cpu')\n",
    "\n",
    "    print(tokenizer.decode(generate))\n",
    "    print(data['chosen'])\n",
    "    print('===============')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
